{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
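    "Here we review several optimizers commonly used in machine learning and compare how many iterations each needs to fit the same small linear-regression problem.\n",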
    "# Gradient Descent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from matplotlib import pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data\n",
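    "We use a small dataset of 2019 salaries for software developers (`x`) and machine learning engineers (`y`) in five Chinese cities: Beijing, Shanghai, Hangzhou, Shenzhen and Guangzhou. Both are divided by 10,000 to keep the numbers close to 1."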
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2019 developer salaries in Beijing, Shanghai, Hangzhou, Shenzhen and Guangzhou\n",
    "x = np.array([13854, 12213, 11009, 10655, 9503]).reshape(5, 1) / 10000.0\n",
    "\n",
    "# 2019 machine learning engineer salaries in the same five cities\n",
    "y = np.array([21332, 20162, 19138, 18621, 18016]).reshape(5, 1) / 10000.0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Functions\n",
    "Model:\n",
    "$$y = ax + b + \\varepsilon, \\qquad \\hat{y}_i = a x_i + b$$\n",
    "Cost function (half the mean squared error):\n",
    "$$J(a,b)=\\frac{1}{2n}\\sum_{i=1}^{n}(y_i-\\hat{y}_i)^2$$\n",
    "Update rule of the optimizer (gradient descent with learning rate $\\alpha$):\n",
    "$$\\theta = \\theta - \\alpha \\frac{\\partial J}{\\partial \\theta}$$\n",
    "For univariate linear regression, $\\theta = (a, b)$, so:\n",
    "$$a = a - \\alpha \\frac{\\partial J}{\\partial a}$$\n",
    "$$b = b - \\alpha \\frac{\\partial J}{\\partial b}$$\n",
    "\n",
    "where $\\frac{\\partial J}{\\partial a}$ and $\\frac{\\partial J}{\\partial b}$ are:\n",
    "\n",
    "$$ \\frac{\\partial J}{\\partial a} = \\frac{1}{n}\\sum_{i=1}^{n}x_i(\\hat{y}_i-y_i)$$\n",
    "\n",
    "$$ \\frac{\\partial J}{\\partial b} = \\frac{1}{n}\\sum_{i=1}^{n}(\\hat{y}_i-y_i)$$"
   ]
  },
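  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of the two partial derivatives above, we can compare them with central finite differences of $J$. This is a minimal sketch: the helper names `cost_J`, `grad_J` and `numeric_grad`, the separate arrays `xs`, `ys` (so the notebook's `x`, `y` stay untouched), the test point $(a, b) = (0.3, -0.2)$ and the step `h` are illustrative choices, not part of the original notebook. Because $J$ is quadratic in $a$ and $b$, the central difference should match the analytic gradient up to rounding error."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# hypothetical copies of the data above, so x and y are not overwritten\n",
    "xs = np.array([13854, 12213, 11009, 10655, 9503]).reshape(5, 1) / 10000.0\n",
    "ys = np.array([21332, 20162, 19138, 18621, 18016]).reshape(5, 1) / 10000.0\n",
    "\n",
    "def cost_J(a, b):\n",
    "    # J(a, b) = 1/(2n) * sum((y_i - y_hat_i)^2)\n",
    "    n = len(xs)\n",
    "    return 0.5 / n * np.square(ys - (a * xs + b)).sum()\n",
    "\n",
    "def grad_J(a, b):\n",
    "    # analytic gradients from the formulas above\n",
    "    n = len(xs)\n",
    "    err = (a * xs + b) - ys  # y_hat - y\n",
    "    return (err * xs).sum() / n, err.sum() / n\n",
    "\n",
    "def numeric_grad(a, b, h=1e-6):\n",
    "    # central differences: (J(p + h) - J(p - h)) / (2h)\n",
    "    da = (cost_J(a + h, b) - cost_J(a - h, b)) / (2 * h)\n",
    "    db = (cost_J(a, b + h) - cost_J(a, b - h)) / (2 * h)\n",
    "    return da, db\n",
    "\n",
    "analytic = np.array(grad_J(0.3, -0.2))\n",
    "numeric = np.array(numeric_grad(0.3, -0.2))\n",
    "print(analytic, numeric, np.allclose(analytic, numeric, atol=1e-6))"
   ]
  },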
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
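    "def model(a, b, x):\n",
    "    # prediction of the linear model: y_hat = a*x + b\n",
    "    return a*x + b\n",
    "\n",
    "def cost_function(a, b, x, y):\n",
    "    # J(a, b) = 1/(2n) * sum((y - y_hat)^2)\n",
    "    n = len(x)\n",
    "    return 0.5/n * (np.square(y-a*x-b)).sum()\n",
    "\n",
    "def sgd(a,b,x,y):\n",
    "    # One step of full-batch gradient descent: every sample is used at each\n",
    "    # step, so despite the name this is not stochastic.\n",
    "    n = len(x)\n",
    "    alpha = 1e-1  # learning rate\n",
    "    y_hat = model(a,b,x)\n",
    "    da = (1.0/n) * ((y_hat-y)*x).sum()  # dJ/da\n",
    "    db = (1.0/n) * ((y_hat-y).sum())    # dJ/db\n",
    "    a = a - alpha*da\n",
    "    b = b - alpha*db\n",
    "    return a, b\n"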
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def iterate_sgd(a,b,x,y,times):\n",
    "    # run `times` gradient-descent steps, then print a, b and the cost and plot the fit\n",
    "    for _ in range(times):\n",
    "        a,b = sgd(a,b,x,y)\n",
    "\n",
    "    y_hat=model(a,b,x)\n",
    "    cost = cost_function(a, b, x, y)\n",
    "    print(a,b,cost)\n",
    "    plt.scatter(x,y)\n",
    "    plt.plot(x,y_hat)\n",
    "    plt.show()\n",
    "    return a,b, cost"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.950768563083351 0.8552812669346652 0.00035532090622957674\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3deXhU5f3+8fdDFhJCIEDYEghhDUtI2BFc6o64ooKtVq2iYv3Wn7VCRNS6FJcqWpdqpbRFal1aEhZxpS4origIkwRI2LcECFsSCFlnnt8foCINEGBmziz367q4JDOHOff1kNx+OHPmHGOtRUREgl8jpwOIiIh3qNBFREKECl1EJESo0EVEQoQKXUQkREQ6tePExESbmprq1O5FRILSkiVLdlprW9f3nGOFnpqayuLFi53avYhIUDLGbDzSczrkIiISIlToIiIhQoUuIhIiVOgiIiFChS4iEiJU6CIiIUKFLiISIlToIiJeVuv28OePVrO8uMyv+3Xsg0UiIqFo1fa9jJ/pIq+ojBq3hz5Jzf22bxW6iIgXuD2Wv322jj/9dxVNYyL5yy8HcGHf9n7NoEIXETlJ63bsY0K2i+82lTKiT1sevbwviU0b+z2HCl1E5AR5PJYZX27gyfkFREc04tmf9+OyfkkYYxzJo0IXETkBm3fvZ0K2i0Xrd3NWWmv+eGUGbZvFOJpJhS4ichystby2aBOPvbuSRsbw5JUZjBnUwbGp/FAqdBGRBiourWTirFw+W72T07ol8sToDJITYp2O9QMVuojIMVhryVmyhT+8tQK3tUwelc61Q1MCYio/lApdROQoSsqrmDQ7j48KShiS2pIpYzLo1CrO6Vj1UqGLiNTDWss8VzEPzltOZY2b+y/qxdhTO9OoUWBN5YdSoYuIHGbXvmrun5vPe/nb6NcxgaevyqRr66ZOxzomFbqIyCHez9/KfXPy2VtVx90XpDHu9C5ERgTHZa9U6CIiQOn+Gh6at5y5y4rpk9SM12/pR1q7eKdjHRcVuoiEvQUFJUyclcvuihruPLc7vzmrG1FBMpUfSoUuImGrvKqWR95ewczFW0hrG8/0GwaTnuy/qyN6mwpdRMLS56t3cneOi23lVdx2ZlfuPLc7jSMjnI51UlToIhJWKqrrePy9lbz69Sa6tI4j57bhDEhp4dV9zF1axJT5hRSXVpKUEEvWiDRG9U/26j7qo0IXkbCxaN0usnJy2bxnPzed1pmsEWnERHl3Kp+7tIhJs/OorHUDUFRayaTZeQA+L3UVuoiEvKpaN1PmFzL9i/V0bNGE/4wbxpDOLX2yrynzC38o8+9VHty/Cl1E5CQs3bSH8dku1u2o4LpTOnHPyJ7ENfZd9RWXVh7X496kQheRkFRd5+bZD1fz10/X0q5ZDK/eNJTTuif6fL9JCbEU1VPeSX64KmPwnWgpInIM+UVlXPrnL3jpk7WMHtiB9393hl/KHCBrRBqxhx2Xj42KIGtEms/3rQldREJGrdvDiwvW8MLHa2gZF830GwZxds+2fs3w/XFyneUiInKCCrftZXz2MvKLyhnVL4mHLu1DQpNoR7KM6p/slwI/nApdRIJandvDXxeu47kPVxMfE8nUawdyQXo7p2M5QoUuIkFr7Y59jJ/pYtnmUkamt+ORUem0atrY6ViOUaGLSNDxeCzTv1jPlPmFxEZH8PzV/bkko33A3RLO31ToIhJUNu6qICs7l2827Oacnm14/Iq+tGkW43SsgKBCF5Gg4PFYXlu0kcffKyDCGKaMzmD0wA5hP5UfSoUuIgGvqLSSu3NcfLFmF6d3T+SJKzP88kGdYKNCF5GAZa0le/EW/vD2CjzW8ujl6VwzJEVT+RGo0EUkIG0vr+KeWbksKNzB0M4teWpMJh1bNnE6VkA75kf/jTEdjTELjDErjTHLjTG/rWebnsaYr4wx1caYCb6JKiLhwFrL3KVFnP/M
Qr5at4sHL+nNG7ecojJvgIZM6HXAeGvtd8aYeGCJMeYDa+2KQ7bZDdwBjPJFSBEJDzv3VXPfnDzmL9/OgJQEnhqTSZfWTZ2OFTSOWejW2q3A1oO/32uMWQkkAysO2aYEKDHGXOSroCIS2t7N28r9c/PZV1XHpJE9ufn0LkQ00rHy43Fcx9CNMalAf2DRiezMGDMOGAeQkpJyIi8hIiFmT0UND85bzjxXMX2Tm/P0VZn0aBvvdKyg1OBCN8Y0BWYBd1pry09kZ9baacA0gEGDBtkTeQ0RCR0frtjOpDl57Kmo4a7zenDbmV2JitBVvU9UgwrdGBPFgTJ/zVo727eRRCTUlVfV8oe3VpCzZAs928Uz48bB9Elq7nSsoHfMQjcHTvj8B7DSWvsn30cSkVC2cNUOJs7KpWRvNbef1Y07zulOdKSmcm9oyIR+KnAdkGeMWXbwsXuBFABr7VRjTDtgMdAM8Bhj7gR6n+ihGREJPfuq63js3ZW8vmgT3do0Zfa1A8nsmOB0rJDSkLNcPgeO+laztXYb0MFboUQktHy1dhdZOS6KSisZd0YX7jqvBzGH3aZNTp4+KSoiPlNZ4+bJ+QW8/MUGOrVqwsxbhzE4taXTsUKWCl1EfGLJxj1MyHaxfmcFvxrWiYkje9IkWpXjS1pdEfGqqlo3z3y4ir8tXEf75rG8fvNQhndLdDpWWFChi4jX5G0p466Zy1hdso+rh3Tk3gt7ER8T5XSssKFCF5GTVlPn4YWPV/PiJ2tp3bQxM24czJlpbZyOFXZU6CJyUlZuLWf8TBcrtpZzxYBkHry4D82baCp3ggpdRE5IndvD1E/X8txHq2keG8W06wZyfp92TscKayp0ETlua0r2Mn6mC9eWMi7KaM/ky9JpGRftdKywp0IXkQZzeyzTP1/PlP8WEhcdwQvX9OfijCSnY8lBKnQRaZANOyuYkO1i8cY9nNe7LY9enk6b+BinY8khVOgiclQej+VfX2/kj+8VEBlh+NNVmVzeP1k3ag5AKnQROaLNu/czcVYuX67dxc96tOaJKzNo11xTeaBSoYvI/7DW8p9vNzP57QN3mnz8ir78YnBHTeUBToUuIj+xrayKe2bn8knhDoZ1acWTozPo2LKJ07GkAVToIgIcmMrnLC3ioXnLqXVbHr60D9ed0olGulFz0FChiwg79lZz75w8PlixnUGdWvDUmExSE+OcjiXHSYUuEubezi3m93Pzqahxc9+FvRh7WmciNJUHJRW6SJjaXVHD79/M553crWR2aM7TV2XSrU2807HkJKjQRcLQf5dv4945+ZRV1pA1Io1bz+hCZIRu1BzsVOgiYaSsspaH31rO7O+K6NW+Ga+MHULvpGZOxxIvUaGLhIlPCku4Z1YeO/ZVc8fZ3bj97O5ER2oqDyUqdJEQt6+6jkffWcEb32yme5umTLt+IBkdEpyOJT6gQhcJYV+u3UlWdi5byyq59Wdd+N25PYiJinA6lviICl0kBO2vqeOJ9wr451cb6ZwYR/avhzOwUwunY4mPqdBFQsziDbuZkO1iw6793DA8lYkX9CQ2WlN5OFChi4SIqlo3f/pgFX/7bB0dWsTy73GncEqXVk7HEj9SoYuEANfmUsZnu1hTso9rhqZw74W9aNpYP97hRn/jIkGsps7D8x+t5qVP19ImvjGvjB3CGT1aOx1LHKJCFwlSy4vLGD/TRcG2vYwe2IHfX9yb5rFRTscSB6nQRYJMrdvDS5+s5fmPVtMiLpq/Xz+Ic3u3dTqWBAAVukgQWbV9L+NnusgrKuPSzCQevrQPLeKinY4lAUKFLhLA5i4tYsr8QopKK2kWE0lFjZvmsVH85ZcDuLBve6fjSYBRoYsEqLlLi5g0O4/KWjcA5VV1NDLwu3O7q8ylXroyj0iAevL9gh/K/HseC1M/XedQIgl0KnSRALR5936Ky6rqfa64tNLPaSRY6JCLSACx1vL6N5t47J2VGMDWs01SQqy/Y0mQ0IQuEiCK
Syu5fvo33Dcnn/4pLXjg4t7EHnZlxNioCLJGpDmUUAKdJnQRh1lryVmyhT+8vQK3xzJ5VDrXDk3BGEOLuGimzC+kuLSSpIRYskakMap/stORJUAds9CNMR2BV4B2gAeYZq197rBtDPAccCGwH7jBWvud9+OKhJaSvVXcOzuPD1eWMCS1JVPGZNCpVdwPz4/qn6wClwZryIReB4y31n5njIkHlhhjPrDWrjhkm5FA94O/hgIvHfyviNTDWstbuVt54M18Kmvc3H9RL8ae2plGjYzT0SSIHbPQrbVbga0Hf7/XGLMSSAYOLfTLgFestRb42hiTYIxpf/DPisghdu2r5vdv5vNu3jb6dUzg6asy6dq6qdOxJAQc1zF0Y0wq0B9YdNhTycDmQ77ecvAxFbrIId7P38Z9c/LYW1XH3RekMe70LkRG6NwE8Y4GF7oxpikwC7jTWlt++NP1/JH/OePKGDMOGAeQkpJyHDFFglvZ/loenJfP3GXF9Elqxuu39COtXbzTsSTENKjQjTFRHCjz16y1s+vZZAvQ8ZCvOwDFh29krZ0GTAMYNGhQfafYioScBQUlTJyVy+6KGn57TnduP7sbUZrKxQcacpaLAf4BrLTW/ukIm80DbjfG/JsDb4aW6fi5hLu9VbU88vZK/rN4M2lt45l+w2DSk5s7HUtCWEMm9FOB64A8Y8yyg4/dC6QAWGunAu9y4JTFNRw4bfFG70cVCR5frNnJ3Tm5bC2r5LYzu3Lnud1pHKkbNYtvNeQsl8+p/xj5odtY4DfeCiUSrCqq6/jjewX86+uNdGkdR85twxmQ0sLpWBIm9ElRES/5Zv1uJmS72LxnPzed1pmsEWnERGkqF/9RoYucpKpaN1PmFzL9i/V0bNGEf99yCkO7tHI6loQhFbrISVi6aQ/js12s21HBdad04p6RPYlrrB8rcYa+80ROQHWdm+c+XM3UT9fSrlkMr940lNO6JzodS8KcCl3kOOUXlTEh20XBtr1cNagD91/cm2YxUU7HElGhizRUrdvDiwvW8MLHa2gZF830GwZxds+2TscS+YEKXaQBCrftZXz2MvKLyhnVL4mHLu1DQpNop2OJ/IQKXeQo6twepn22jmc/WE18TCRTrx3IBentnI4lUi8VusgRrN2xj/EzXSzbXMrI9HY8MiqdVk0bOx1L5IhU6CKH8XgsL3+5gSffLyA2OoLnr+7PJRntOXBZI5HApUIXOcSmXfuZkOPim/W7OadnGx6/oi9tmsU4HUukQVToIhy4Jdyrizbx+LsriTCGKaMzGD2wg6ZyCSoqdAl7RaWVTMzJ5fM1Ozm9eyJPXJlBUkKs07FEjpsKXcKWtZbsxVuY/PYK3Nby6OXpXDMkRVO5BC0VuoSl7eVVTJqdx8cFJQzt3JIpozNJadXE6VgiJ0WFLmHFWss8VzEPvLmc6jo3D1zcmxuGp9KokaZyCX4qdAkbO/dVc/+cfN5fvo0BKQk8NSaTLq2bOh1LxGtU6BIW3svbyn1z89lXVcc9I3tyy+ldiNBULiFGhS4hrXR/DQ+8uZx5rmL6Jjfn6asy6dE23ulYIj6hQpeQ9dHK7dwzO489FTXcdV4PbjuzK1ERjZyOJeIzKnQJOeVVtfzhrRXkLNlCz3bxvHzDYNKTmzsdS8TnVOgSUhau2sHEWbmU7K3m9rO6ccc53YmO1FQu4UGFLiGhorqOx95dyWuLNtG1dRyzbhtOv44JTscS8SsVugS9r9ftIivHxZY9ldxyemfGn59GTFSE07FE/E6FLkGrssbNk/MLePmLDXRq1YSZtw5jcGpLp2OJOEaFLkFpycY9TMh2sX5nBb8a1omJI3vSJFrfzhLe9BMgQaWq1s0zH67ibwvX0b55LK/fPJTh3RKdjiUSEFToEjTytpRx18xlrC7Zxy8Gd+S+i3oRHxPldCyRgKFCl4BXU+fhhQVreHHBGhKbRvPyjYM5K62N07FEAo4KXQJawbZy7vqPixVby7mifzIP
XtKH5k00lYvUR4UuAanO7eGvC9fx7IeraB4bxbTrBnJ+n3ZOxxIJaCp0CThrSvYyPjsX1+ZSLspoz+TL0mkZF+10LJGAp0KXgOH2WKZ/vp4p/y0kLjqCF67pz8UZSU7HEgkaKnTxmblLi5gyv5Di0kqSEmLJGpHGqP7J9W67YWcFWTkuvt2wh3N7teWxK9JpEx/j58QiwU2FLj4xd2kRk2bnUVnrBqCotJJJs/MAflLqHo/l1UUbefzdAiIjDE+PyeSKAcm6UbPICVChi09MmV/4Q5l/r7LWzZT5hT8U+pY9+7k7J5cv1+7ijB6teeLKvrRvHutEXJGQoEIXnygurTzi49Za/vPtZh55ZyXWWh6/oi+/GNxRU7nISVKhi08kJcRSVE+pt20Ww40zvuWTwh0M69KKJ0dn0LFlEwcSioQeXflffCJrRBqxh13CNirCUFZZy9frdvHwpX147eahKnMRLzrmhG6MmQ5cDJRYa9Preb4FMB3oClQBY621+d4OKsHl++PkU+YXUlRaSUxkI6rqPGR0aMZTYzLpnBjncEKR0NOQCX0GcMFRnr8XWGatzQCuB57zQi4JAaP6J3Pvhb1o0SQKD3Dfhb2YeeswlbmIjxxzQrfWLjTGpB5lk97A4we3LTDGpBpj2lprt3snogSjPRU1/P7NfN7O3Upmh+Y8fVUm3drEOx1LJKR5401RF3AF8LkxZgjQCegA/E+hG2PGAeMAUlJSvLBrCUQfrNjOpNl5lFXWMOH8Hvz6Z12JjNDbNSK+5o1C/yPwnDFmGZAHLAXq6tvQWjsNmAYwaNAg64V9SwApq6zl4beWM/u7Inq1b8YrY4fQO6mZ07FEwsZJF7q1thy4EcAcOJF4/cFfEkY+XbWDiTm57NhXzR1nd+P2s7sTHampXMSfTrrQjTEJwH5rbQ1wM7DwYMlLGNhXXcej76zgjW82061NU6ZdP5CMDglOxxIJSw05bfEN4Ewg0RizBXgQiAKw1k4FegGvGGPcwArgJp+llYDy5dqdZGXnUlxWya1ndOF35/Ug5rBzz0XEfxpylsvVx3j+K6C71xJJwNtfU8eT7xcy48sNpLZqQs6vhzGwU0unY4mEPX30X47L4g27mZDtYsOu/dwwPJWJF/QkNlpTuUggUKFLg1TVuvnTB6v422frSE6I5Y1bTmFY11ZOxxKRQ6jQ5Zhcm0sZn+1iTck+rhmawr0X9qJpY33riAQa/VTKEdXUeXj+o9W89OlaWjdtzD/HDuFnPVo7HUtEjkCFLvVaUVzOXTOXUbBtL1cO6MADl/SmeWyU07FE5ChU6PITtW4PL32yluc/Wk1Ck2j+dv0gzuvd1ulYItIAKnT5wartexk/00VeURmXZCbxh0v70CIu2ulYItJAKnTB7bH8/bN1PP3fVTSNieTFawZwUUZ7p2OJyHFSoYe59TsrmJDtYsnGPYzo05ZHRvWldXxjp2OJyAlQoYcpj8fyz6828MT7BURHNOLZn/fjsn5JulGzSBBToYehzbv3k5Xj4ut1u2kc2YjyqjqmzC8Efrx1nIgEHxV6GLHW8sY3m3n0nRXUeSxREYbqOg8ARaWVTJqdB6jURYKVLlgdJraWVfKrl7/l3jl59EtJIKFJFLXun95jpLLW/cOkLiLBR4Ue4qy1ZC/ezPnPLOTb9buZfFkf/jV2KCXl1fVuX1xa6eeEIuItOuQSwkr2VnHv7Dw+XFnC4NQWTBmdSWpiHABJCbEU1VPeSQmx/o4pIl6iCT0EWWuZ5yrm/GcW8tnqndx/US/+PW7YD2UOkDUijdjDbkYRGxVB1og0f8cVES/RhB5idu2r5vdv5vNu3jYyOybw9JhMurVp+j/bff/G55T5hRSXVpKUEEvWiDS9ISoSxFToIeT9/G3cPzePsspaskakcesZXYiMOPI/wkb1T1aBi4QQFXoIKNtfy0NvLWfO0iL6JDXj1ZuH0rNdM6djiYifqdCD3ILCEu6ZlcuufTX89pzu3H52
N6KOMpWLSOhSoQepvVW1PPL2Sv6zeDM92jbl79cPpm+H5k7HEhEHqdCD0BdrdnJ3Ti5byyr59c+68rvzutM4UjdqFgl3KvQgUlFdxx/fK+BfX2+kS2IcObcNZ0BKC6djiUiAUKEHiW/W72ZCtovNe/Yz9tTOB84jj9ZULiI/UqEHuKpaN0/NL+QfX6ynQ4tY/n3LKQzt0srpWCISgFToAWzppj2Mz3axbkcF156SwqSRvYhrrL8yEamf2iEAVde5ee7D1Uz9dC3tmsXwr5uGcHr31k7HEpEAp0IPMPlFZUzIdlGwbS9XDerA/Rf3pllMlNOxRCQIqNADRK3bw4sL1vDCx2toERfN9BsGcXbPtk7HEpEgokIPAIXb9jI+exn5ReVc1i+Jhy/tQ0KTaKdjiUiQUaE7yO2xTFu4jmc+WEV8TCRTrx3ABentnY4lIkFKhe6QtTv2MSHbxdJNpYxMb8fkUekkNm3sdCwRCWIqdD/zeCwvf7mBJ98vICYqgud+0Y9LM5MwxjgdTUSCnArdjzbt2s+EHBffrN/N2T3b8PgVfWnbLMbpWCISIlTofmCt5bVFm3js3ZVEGMOTozMYM7CDpnIR8SoVuo8Vl1YycVYun63eyendE3niygzdiFlEfEKF7iPWWrKXbGHyWytwW8sjo9L55dAUTeUi4jMqdB8oKa9i0uw8PiooYUjnljw1OpOUVk2cjiUiIS4kCn3u0qKAuHu9tZZ5rmIeeHM5VbVuHri4NzcMT6VRI03lIuJ7xyx0Y8x04GKgxFqbXs/zzYFXgZSDr/eUtfZlbwc9krlLi5g0O4/KWjcARaWVTJqdB+DXUt+5r5r75+Tz/vJt9E9J4KkxmXRt3dRv+xcRacjdhGcAFxzl+d8AK6y1mcCZwNPGGL99bn3K/MIfyvx7lbVupswv9FcE3svbyohnFvJxQQkTL+hJzq+Hq8xFxO+OOaFbaxcaY1KPtgkQbw6829cU2A3UeSVdAxSXVh7X495Uur+GB+ct581lxaQnN+P1Mf1Iaxfv8/2KiNTHG8fQXwDmAcVAPPBza62nvg2NMeOAcQApKSle2DUkJcRSVE95+/rUwI8LtnPPrDx2V9Twu3N78H9ndSUqoiH/4BER8Q1vNNAIYBmQBPQDXjDGNKtvQ2vtNGvtIGvtoNatvXPDhqwRacRG/fTemrFREWSNSPPK6x+uvKqWrGwXY2cspmVcNHN/cyq/Pbe7ylxEHOeNCf1G4I/WWgusMcasB3oC33jhtY/p+zc+/XGWy2erdzAxJ5dt5VX85qyu3HFOdxpH6kbNIhIYvFHom4BzgM+MMW2BNGCdF163wUb1T/bpGS0V1XU8/t5KXv16E11axzHrtuH0T2nhs/2JiJyIhpy2+AYHzl5JNMZsAR4EogCstVOBycAMY0weYICJ1tqdPkvsZ4vW7SIrJ5fNe/Zzy+mdGX9+GjFRmspFJPA05CyXq4/xfDFwvtcSBYjKmgOnPr785XpSWjZh5q3DGJza0ulYIiJHFBKfFPW27zbtYcJMF+t2VnD9sE7cM7InTaK1VCIS2NRSh6iuc/PMB6uZtnAt7ZvH8trNQzm1W6LTsUREGkSFflDeljLGZy9j1fZ9/GJwR+67qBfxMVFOxxIRabCwL/SaOg8vLFjDiwvWkNg0mpdvHMxZaW2cjiUictzCutALtpUzfqaL5cXlXN4/mYcu6UPzJprKRSQ4hWWh17k9/HXhOp79cBXNY6P463UDGdGnndOxREROStgV+pqSfYzPduHaXMpFfdszeVQ6LeP8dnFIERGfCZtCd3ssL3+xninzC4mNjuDPV/fnkswkp2OJiHhNWBT6xl0VTMh28e2GPZzbqy2PXZFOm/gYp2OJiHhVSBe6x2N5bdFGHnu3gMgIw1NjMrlyQLJu1CwiISlkC33Lnv1MnJXLF2t2cXr3RJ4cnUH75r69RrqIiJNCrtCttcxcvJnJb6/EWstjl/fl6iEdNZWLSMgLqULfXl7F
PbNyWVC4g1O6tGTK6Ew6tmzidCwREb8ImUL/bPUOfvPad9S4PTx0SW+uH5ZKo0aaykUkfIRMoae2iqNfSgsevrQPnRPjnI4jIuJ3IVPoHVs24ZWxQ5yOISLiGN3ZWEQkRKjQRURChApdRCREqNBFREKECl1EJESo0EVEQoQKXUQkRKjQRURChLHWOrNjY3YAGx3Zue8lAjudDhEgtBY/0lr8SGvxU8ezHp2sta3re8KxQg9lxpjF1tpBTucIBFqLH2ktfqS1+ClvrYcOuYiIhAgVuohIiFCh+8Y0pwMEEK3Fj7QWP9Ja/JRX1kPH0EVEQoQmdBGREKFCFxEJESr0E2SMmW6MKTHG5B/heWOMed4Ys8YYk2uMGeDvjP7SgLXoaYz5yhhTbYyZ4O98/tSAtfjlwe+HXGPMl8aYTH9n9KcGrMdlB9dimTFmsTHmNH9n9JdjrcUh2w02xriNMaOPdx8q9BM3A7jgKM+PBLof/DUOeMkPmZwyg6OvxW7gDuApv6Rx1gyOvhbrgZ9ZazOAyYT+m4MzOPp6fARkWmv7AWOBv/sjlENmcPS1wBgTATwBzD+RHajQT5C1diEHiupILgNesQd8DSQYY9r7J51/HWstrLUl1tpvgVr/pXJGA9biS2vtnoNffg108EswhzRgPfbZH8/MiANC9iyNBnQGwP8DZgElJ7IPFbrvJAObD/l6y8HHRL53E/Ce0yGcZoy53BhTALzDgSk9LBljkoHLgakn+hoqdN8x9TwWstOHHB9jzFkcKPSJTmdxmrV2jrW2JzCKA4ehwtWzwERrrftEXyDSi2Hkp7YAHQ/5ugNQ7FAWCSDGmAwOHCseaa3d5XSeQGGtXWiM6WqMSbTWhuOFuwYB/zbGwIGLdV1ojKmz1s5t6AtoQvedecD1B892OQUos9ZudTqUOMsYkwLMBq6z1q5yOo/TjDHdzMEGO3gmWDQQlv+Ts9Z2ttamWmtTgRzg/46nzEET+gkzxrwBnAkkGmO2AA8CUQDW2qnAu8CFwBpgP3CjM0l971hrYYxpBywGmgEeY8ydQG9rbblDkX2mAd8XDwCtgL8c7LG6UL7qYAPW40oODD61QCXw80PeJA0pDViLk99HiK6diEjY0SEXEZEQoUIXEXwiXsEAAAAkSURBVAkRKnQRkRChQhcRCREqdBGREKFCFxEJESp0EZEQ8f8BY85ggZiYX+QAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "a=0\n",
    "b=0\n",
    "_, _, sgd_cost = iterate_sgd(a,b,x,y,100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.00035532090622957674"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sgd_cost"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After 100 iterations the fitted line already follows the data closely (cost ≈ 3.55e-4). We record this cost as a target: for each optimizer below, we count how many iterations it needs to bring the cost below the same value."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
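    "def iterate(a, b, x, y, target_cost, func):\n",
    "    # Apply the update rule `func` until the cost drops strictly below\n",
    "    # target_cost and return the (0-based) index of that iteration.\n",
    "    # If the target is never reached, the last index (999) is returned.\n",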
    "    i=0\n",
    "    for i in range(1000):\n",
    "        a,b = func(a,b,x,y)\n",
    "        cost = cost_function(a, b, x, y)\n",
    "        if cost<target_cost:\n",
    "            break\n",
    "    return i"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check, running plain gradient descent again from $a=b=0$ should reach the recorded cost at about iteration 100."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "100"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "iterate(a,b,x,y, sgd_cost, sgd)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Momentum\n",
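    "Momentum was proposed by Boris Polyak in 1964. Instead of stepping along the current gradient alone, it keeps a momentum vector $m$ (a running velocity) and adds it to the parameters $\\theta$. At each step the previous momentum is multiplied by $\\beta$ (here 0.9) and the scaled gradient is subtracted. While the derivative keeps pointing in the same direction, the momentum grows (up to $\\alpha/(1-\\beta)$ times the gradient in magnitude) and the optimum is approached faster. When the derivative changes sign, the new contribution partly cancels the accumulated momentum, so it shrinks and damps oscillations, then builds up again once the direction is stable.\n",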
    "\n",
    "$$ m = \\beta m - \\alpha \\frac{\\partial J}{\\partial \\theta}$$\n",
    "\n",
    "$$ \\theta = \\theta + m $$"
   ]
  },
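  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The growth of $m$ can be checked with a minimal sketch that feeds a constant, hypothetical gradient `g` into the momentum update. Setting $m = \\beta m - \\alpha g$ and solving for $m$ gives the limit $m = -\\alpha g/(1-\\beta)$; with $\\alpha = 0.1$ and $\\beta = 0.9$ this is ten times the plain gradient-descent step $\\alpha g$. The names `g`, `history` and `limit` are illustrative only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "alpha, beta = 0.1, 0.9\n",
    "g = 1.0  # hypothetical constant gradient\n",
    "m = 0.0\n",
    "history = []\n",
    "for _ in range(100):\n",
    "    m = beta * m - alpha * g  # momentum update from the formula above\n",
    "    history.append(m)\n",
    "\n",
    "limit = -alpha * g / (1 - beta)  # fixed point of m = beta*m - alpha*g\n",
    "print(history[0], history[9], history[-1], limit)"
   ]
  },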
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
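    "def momentum(a, b, ma, mb, x, y):\n",
    "    # one step of momentum optimization; ma, mb are the momenta of a and b\n",
    "    n = len(x)\n",
    "    alpha = 1e-1  # learning rate\n",
    "    beta = 0.9    # momentum coefficient\n",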
    "    y_hat = model(a,b,x)\n",
    "    da = (1.0/n) * ((y_hat-y)*x).sum()\n",
    "    db = (1.0/n) * ((y_hat-y).sum())\n",
    "    ma = beta*ma - alpha*da\n",
    "    mb = beta*mb - alpha*db\n",
    "    a = a + ma\n",
    "    b = b + mb\n",
    "    return a, b, ma, mb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def iterate_momentum(x, y, target_cost):\n",
    "    a=0\n",
    "    b=0\n",
    "    ma=0\n",
    "    mb=0\n",
    "    for i in range(1000):\n",
    "        a, b, ma, mb = momentum(a,b, ma, mb, x,y)\n",
    "        print(f\"{ma}\\t{mb}\")\n",
    "        cost = cost_function(a, b, x, y)\n",
    "        if cost<target_cost:\n",
    "            break\n",
    "    return i"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.22441501579999998\t0.19453800000000002\n",
      "0.37422247273264564\t0.32448006197140566\n",
      "0.42205082627737606\t0.3661434135397697\n",
      "0.3669569357577927\t0.31871477461500164\n",
      "0.23200196838148313\t0.2021526955988872\n",
      "0.056494030025446756\t0.050474753607803874\n",
      "-0.11476727348134302\t-0.09754962817390435\n",
      "-0.24244627038078237\t-0.20787942869918896\n",
      "-0.3012619304723101\t-0.25863596662407884\n",
      "-0.28445487409046366\t-0.24396840343676746\n",
      "-0.20350545891969496\t-0.17380977569712325\n",
      "-0.08364332123714681\t-0.0699911702901115\n",
      "0.04338773008400103\t0.040019175300583965\n",
      "0.147354432709795\t0.13006006211489604\n",
      "0.20640546147925654\t0.1812234868328625\n",
      "0.21130873793163232\t0.1855214000311385\n",
      "0.16633382496141397\t0.146649293292915\n",
      "0.08690993480643269\t0.0779595676235392\n",
      "-0.00507401470797135\t-0.0016055486586756679\n",
      "-0.08699979992056481\t-0.07247278613120942\n",
      "-0.1408466468646064\t-0.11904732814606164\n",
      "-0.15691741066134915\t-0.13293724917152472\n",
      "-0.1352587212815897\t-0.1141824310137057\n",
      "-0.08467582859413841\t-0.07040205626263701\n",
      "-0.01981150872173977\t-0.014266840612897652\n",
      "0.0428388633316896\t0.03994932131351764\n",
      "0.0889440734138505\t0.07984525590808778\n",
      "0.10944946511171101\t0.09758582125685551\n",
      "0.10215246324062663\t0.09126528657265356\n",
      "0.07152894890411822\t0.06475708853737841\n",
      "0.02702573023397084\t0.026236225728736323\n",
      "-0.01963087668527036\t-0.014149754660337372\n",
      "-0.05738680603496786\t-0.04683505435206044\n",
      "-0.0783606725095876\t-0.06499936572619444\n",
      "-0.07935720748013644\t-0.06587751992946814\n",
      "-0.06214086021009116\t-0.05099624589363139\n",
      "-0.03252999577624324\t-0.02539033468548647\n",
      "0.0013599704310883268\t0.003917662426907619\n",
      "0.031231311028720562\t0.029747420490065724\n",
      "0.05054959967538386\t0.04644447498706586\n",
      "0.055885184899792356\t0.051041064960017556\n",
      "0.047399341455049274\t0.04367682409456289\n",
      "0.028447708284948785\t0.02725561708852082\n",
      "0.0044813976566155725\t0.006494616802269353\n",
      "-0.018428741728553195\t-0.013352721762541375\n",
      "-0.035064239440853656\t-0.027770553086432857\n",
      "-0.04218591581428813\t-0.033955812608976274\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "46"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "iterate_momentum(x, y, target_cost=sgd_cost)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Voilà: momentum reaches the same cost in about 46 iterations, less than half of the 100 that plain gradient descent needed."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Nesterov Accelerated Gradient\n",
    "One small variant of momentum optimization, proposed by Yurii Nesterov in 1983, is almost always faster than vanilla momentum optimization. The Nesterov Accelerated Gradient (NAG) method, also known as Nesterov momentum optimization, measures the gradient of the cost function not at the local position $\\theta$ but slightly ahead in the direction of the momentum, at $\\theta + \\beta m$.\n",
    "\n",
    "Update rule:\n",
    "$$ m = \\beta m - \\alpha \\frac{\\partial J(\\theta+\\beta m)}{\\partial \\theta}$$\n",
    "\n",
    "$$ \\theta = \\theta + m $$\n",
    "\n",
    "![Nesterov](images/nesterov.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
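    "For linear regression, the gradients keep the same form:\n",
    "\n",
    "$$ \\frac{\\partial J}{\\partial a} = \\frac{1}{n}\\sum_{i=1}^{n}x_i(\\hat{y}_i-y_i)$$\n",
    "\n",
    "$$ \\frac{\\partial J}{\\partial b} = \\frac{1}{n}\\sum_{i=1}^{n}(\\hat{y}_i-y_i)$$\n",
    "but the predictions are made at the look-ahead parameters. The code below looks ahead by the full previous momentum,\n",
    "$$\\hat{y}_i=(a+m_a)x_i+(b+m_b),$$\n",
    "which is a slight simplification of the look-ahead point $\\theta+\\beta m$ in the formula above (with $\\beta = 0.9$ the two points are close)."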
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
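    "def nesterov(a, b, ma, mb, x, y):\n",
    "    n = len(x)\n",
    "    alpha = 1e-1  # learning rate\n",
    "    beta = 0.9    # momentum coefficient\n",
    "    # The only change from momentum(): the gradient is evaluated at the\n",
    "    # look-ahead point (a + ma, b + mb) instead of at (a, b).\n",
    "    y_hat = model(a+ma,b+mb,x)\n",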
    "    da = (1.0/n) * ((y_hat-y)*x).sum()\n",
    "    db = (1.0/n) * ((y_hat-y).sum())\n",
    "    ma = beta*ma - alpha*da\n",
    "    mb = beta*mb - alpha*db\n",
    "    a = a + ma\n",
    "    b = b + mb\n",
    "    return a, b, ma, mb"
   ]
  },
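  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For comparison, here is a minimal sketch of the textbook look-ahead $\\theta + \\beta m$ from the formula above, written with a parameter vector $\\theta = (a, b)$. The helper names `grad` and `nag_step`, the 1-D arrays `xs`, `ys` and the 2000 iterations are illustrative assumptions, not the notebook's code; after enough steps the result should agree with the least-squares line returned by `np.polyfit`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# 1-D copies of the data (hypothetical names, so x and y are untouched)\n",
    "xs = np.array([13854, 12213, 11009, 10655, 9503]) / 10000.0\n",
    "ys = np.array([21332, 20162, 19138, 18621, 18016]) / 10000.0\n",
    "\n",
    "def grad(theta):\n",
    "    # gradient of J(a, b) = 1/(2n) * sum((a*x_i + b - y_i)^2)\n",
    "    a, b = theta\n",
    "    err = a * xs + b - ys\n",
    "    return np.array([(err * xs).mean(), err.mean()])\n",
    "\n",
    "def nag_step(theta, m, alpha=0.1, beta=0.9):\n",
    "    # evaluate the gradient at the look-ahead point theta + beta*m\n",
    "    m = beta * m - alpha * grad(theta + beta * m)\n",
    "    return theta + m, m\n",
    "\n",
    "theta, m = np.zeros(2), np.zeros(2)\n",
    "for _ in range(2000):\n",
    "    theta, m = nag_step(theta, m)\n",
    "\n",
    "print(theta, np.polyfit(xs, ys, 1))  # [a, b] vs. [slope, intercept]"
   ]
  },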
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def iterate_nesterov(x, y, target_cost):\n",
    "    a=0\n",
    "    b=0\n",
    "    ma=0\n",
    "    mb=0\n",
    "    for i in range(1000):\n",
    "        a, b, ma, mb = nesterov(a,b, ma, mb, x,y)\n",
    "        print(ma, mb)\n",
    "        cost = cost_function(a, b, x, y)\n",
    "        if cost<target_cost:\n",
    "            break\n",
    "    return i"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.22441501579999998 0.19453800000000002\n",
      "0.3220564154452913 0.2793379239428112\n",
      "0.3123373193267518 0.2712021012049902\n",
      "0.23316116137430112 0.2029331301200368\n",
      "0.12597323633919164 0.11039828486837797\n",
      "0.02495663925447264 0.02318026316585743\n",
      "-0.04849493913806344 -0.0402057492598321\n",
      "-0.08649709799457425 -0.07293501030570788\n",
      "-0.09201745463834532 -0.07757373717860767\n",
      "-0.0745806082761438 -0.062362386777225415\n",
      "-0.04587718632305795 -0.03741593563963194\n",
      "-0.016328775542062378 -0.011751334705692\n",
      "0.006910985026297077 0.008442254305498949\n",
      "0.020532049949227454 0.020301604537172682\n",
      "0.024559560846665853 0.02384948551560824\n",
      "0.02123968741689477 0.021030536818391725\n",
      "0.013779790759404111 0.014621078911301263\n",
      "0.005283948295620864 0.0073079790564409405\n",
      "-0.001933819282152938 0.0010943603685497844\n",
      "-0.006624599047500719 -0.0029384067739360196\n",
      "-0.008540833610722172 -0.004575531028581929\n",
      "-0.008161152003470043 -0.004230677649705485\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "21"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "iterate_nesterov(x, y, target_cost=sgd_cost)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Nesterov momentum needs only about 21 iterations, less than half of what classical momentum needed (46)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# AdaGrad\n",
    "\n",
    "Consider an elongated bowl-shaped cost surface: Gradient Descent starts by quickly going down the steepest slope, which does not point straight toward the global optimum, and then very slowly goes down to the bottom of the valley. It would be nice if the algorithm could correct its direction earlier to point a bit more toward the global optimum. AdaGrad achieves this correction by scaling down the gradient vector along the steepest dimensions: each coordinate is divided by the square root of the sum of its past squared gradients.\n",
    "\n",
    "$$\\epsilon=10^{-10}$$\n",
    "$$ s = s + \\frac{\\partial J}{\\partial \\theta} \\odot \\frac{\\partial J}{\\partial \\theta} $$\n",
    "$$ \\theta = \\theta - \\alpha \\frac{\\partial J}{\\partial \\theta} \\oslash \\sqrt{s+\\epsilon} $$\n",
    "where $\\odot$ and $\\oslash$ denote element-wise multiplication and division.\n",
    "\n",
    "![AdaGrad](images/adaGrad.png)"
   ]
  },
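  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The scaling can be illustrated with a minimal sketch using a constant, hypothetical gradient whose first coordinate is 100 times steeper than the second. AdaGrad gives both coordinates almost the same step size ($\\alpha$ on the first iteration), and because $s$ keeps growing, the step then shrinks like $\\alpha/\\sqrt{t}$ after $t$ iterations. The gradient values are made up for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "alpha, eps = 0.1, 1e-10\n",
    "g = np.array([10.0, 0.1])  # hypothetical gradient: steep vs. shallow direction\n",
    "s = np.zeros(2)\n",
    "steps = []\n",
    "for _ in range(100):\n",
    "    s = s + g * g  # accumulate squared gradients\n",
    "    steps.append(alpha * g / np.sqrt(s + eps))  # per-coordinate step\n",
    "steps = np.array(steps)\n",
    "\n",
    "print(steps[0])   # first step: about alpha in both coordinates\n",
    "print(steps[-1])  # after 100 steps: about alpha / sqrt(100)"
   ]
  },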
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ada_grad(a,b,sa, sb, x,y):\n",
    "    epsilon=1e-10\n",
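    "    n = len(x)\n",
    "    alpha = 1e-1\n",
    "    y_hat = model(a,b,x)\n",
    "    da = (1.0/n) * ((y_hat-y)*x).sum()\n",
    "    db = (1.0/n) * ((y_hat-y).sum())\n",
    "    # accumulate squared gradients; epsilon is folded into s at every step\n",
    "    # rather than added inside the square root (a negligible difference)\n",
    "    sa=sa+da*da + epsilon\n",
    "    sb=sb+db*db + epsilon\n",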
    "    a = a - alpha*da / np.sqrt(sa)\n",
    "    b = b - alpha*db / np.sqrt(sb)\n",
    "    return a, b, sa, sb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [],
   "source": [
    "def iterate_ada_grad(x, y, target_cost):\n",
    "    a=0\n",
    "    b=0\n",
    "    sa=0\n",
    "    sb=0\n",
    "    for i in range(1000):\n",
    "        a, b, sa, sb = ada_grad(a,b, sa, sb, x,y)\n",
    "        print(f\"{sa}\\t{sb}\")\n",
    "        cost = cost_function(a, b, x, y)\n",
    "        if cost<target_cost:\n",
    "            break\n",
    "    return i"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5.036209931751423 3.7845033445000005\n",
      "9.022051275308753 6.780559696352509\n",
      "12.37760305876143 9.303433839935428\n",
      "15.277271622803307 11.48402050624463\n",
      "17.82104257560732 13.397362865404352\n",
      "20.075029302440374 15.093083234334934\n",
      "22.086658442962953 16.606779524600782\n",
      "23.891795301524446 17.965372780402355\n",
      "25.518582625798217 19.189988154097705\n",
      "26.989713284113275 20.297660542873828\n",
      "28.323871617739783 21.30241632396213\n",
      "29.536696895192907 22.215996430588216\n",
      "30.641454625769647 23.048360177403552\n",
      "31.64952032868011 23.808048334075067\n",
      "32.570737998208436 24.502452158091806\n",
      "33.41369201864302 25.138017471314626\n",
      "34.1859175911238 25.72040258936349\n",
      "34.89406641053884 26.254602665813465\n",
      "35.5440390853296 26.745049076615274\n",
      "36.14109238402617 27.195689911681644\n",
      "36.68992711613361 27.610055932293786\n",
      "37.19476089865728 27.991315184952946\n",
      "37.65938897276373 28.342318646661948\n",
      "38.08723546173832 28.665638696230218\n",
      "38.48139690171807 28.96360178615045\n",
      "38.84467946537304 29.238316380907396\n",
      "39.179630992211976 29.491696997550104\n",
      "39.4885687078732 29.725485010754486\n",
      "39.773603338143715 29.94126675204925\n",
      "40.036660187143156 30.140489330578298\n",
      "40.27949764286802 30.324474523037136\n",
      "40.50372348972214 30.494431017700293\n",
      "40.71080934136093 30.651465247698887\n",
      "40.90210345416155 30.79659100891801\n",
      "41.078842138921175 30.930738025829122\n",
      "41.24215995373791 31.054759602568705\n",
      "41.393098832737515 31.169439475341107\n",
      "41.53261628206476 31.275497964778044\n",
      "41.66159275534983 31.373597512470813\n",
      "41.78083830489881 31.464347673911906\n",
      "41.89109859152528 31.548309630076606\n",
      "41.99306032474861 31.626000271475725\n",
      "42.087356195645654 31.6978959014265\n",
      "42.17456935664612 31.76443559928711\n",
      "42.25523749575812 31.82602427929392\n",
      "42.329856546897744 31.883035476277904\n",
      "42.39888407301105 31.935813885795167\n",
      "42.46274235438478 31.984677682985065\n",
      "42.52182121083374 32.029920641686175\n",
      "42.57648058323726 32.071814072927175\n",
      "42.6270528971003 32.110608599810575\n",
      "42.67384522837379 32.14653578397523\n",
      "42.71714128963229 32.17980961722002\n",
      "42.75720325283092 32.210627890463314\n",
      "42.79427342321258 32.239173450973354\n",
      "42.82857577747798 32.26561535771065\n",
      "42.86031737804119 32.29010994365495\n",
      "42.889689674048164 32.31280179313012\n",
      "42.91686969881755 32.3338246413763\n",
      "42.94202117245539 32.35330220293709\n",
      "42.96529551758434 32.37134893482122\n",
      "42.986832795402066 32.38807073985339\n",
      "43.00676256863235 32.40356561513984\n",
      "43.025204697347036 32.417924250135606\n",
      "43.04227007311023 32.43123057840443\n",
      "43.058061296420725 32.4435622868061\n",
      "43.07267330199957 32.45499128552349\n",
      "43.086193936081116 32.46558414205043\n",
      "43.09870448951406 32.475402481997115\n",
      "43.110280190159735 32.48450335933047\n",
      "43.12099065778465 32.49293959844891\n",
      "43.130900324380626 32.50076011029321\n",
      "43.14006882260533 32.50801018451441\n",
      "43.14855134481713 32.51473175955591\n",
      "43.15639897497798 32.52096367235614\n",
      "43.163658995515654 32.52674188924168\n",
      "43.17037517106928 32.53209971945499\n",
      "43.17658801088964 32.53706801264641\n",
      "43.182335011525204 32.54167534155483\n",
      "43.18765088129709 32.54594817100533\n",
      "43.19256774794783 32.54991101426352\n",
      "43.19711535074128 32.55358657770543\n",
      "43.20132121819138 32.556995894687155\n",
      "43.205210832506516 32.560158449430105\n",
      "43.20880778175224 32.563092291674785\n",
      "43.21213390065799 32.56581414279806\n",
      "43.2152094009225 32.56833949403565\n",
      "43.218052991807106 32.57068269740245\n",
      "43.220681991746005 32.57285704985823\n",
      "43.22311243164701 32.574874871224374\n",
      "43.22535915050512 32.576747576319114\n",
      "43.22743588390416 32.578485741743236\n",
      "43.22935534593807 32.58009916771559\n",
      "43.2311293050434 32.58159693532748\n",
      "43.23276865419746 32.58298745955748\n",
      "43.2342834759022 32.584278538362064\n",
      "43.23568310234268 32.58547739813427\n",
      "43.23697617107932 32.58659073580031\n",
      "43.238170676606536 32.58762475780386\n",
      "43.23927401808534 32.58858521620936\n",
      "43.24029304353432 32.5894774421379\n",
      "43.24123409074243 32.59030637673369\n",
      "43.24210302514704 32.59107659984421\n",
      "43.242905274902746 32.59179235658335\n",
      "43.243645863349556 32.59245758193453\n",
      "43.24432943907341 32.593075923538784\n",
      "43.244960303737784 32.5936507628022\n",
      "43.24554243785174 32.5941852344471\n",
      "43.24607952462745 32.59468224462204\n",
      "43.24657497206891 32.59514448767715\n",
      "43.24703193342299 32.595574461703606\n",
      "43.24745332611418 32.59597448292836\n",
      "43.24784184927554 32.59634669904891\n",
      "43.248199999979725 32.596693101586226\n",
      "43.2485300882666 32.59701553732847\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "114"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "iterate_ada_grad(x, y, target_cost=sgd_cost)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we can see, AdaGrad is even slower than SGD here.\n",
    "\n",
    "# RMSProp\n",
    "\n",
    "As we’ve seen, AdaGrad runs the risk of slowing down too fast and never\n",
    "converging to the global optimum: its sum of squared gradients only ever grows,\n",
    "so its effective learning rate only ever shrinks. The RMSProp algorithm fixes this by\n",
    "accumulating only the gradients from the most recent iterations (as opposed to\n",
    "all the gradients since the beginning of training). It does so by applying exponential\n",
    "decay with rate $\\beta$ in the first equation below.\n",
    "\n",
    "\n",
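    "Here $\\beta$ is the decay rate (0.9 in the code below), $\\epsilon = 10^{-10}$ is a smoothing term that avoids division by zero, and $\\odot$ and $\\oslash$ denote element-wise multiplication and division:\n",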
    "$$ s = \\beta s + (1-\\beta) \\frac{\\partial J}{\\partial \\theta} \\odot \\frac{\\partial J}{\\partial \\theta} $$\n",
    "$$ \\theta = \\theta - \\alpha \\frac{\\partial J}{\\partial \\theta} \\oslash \\sqrt{s+\\epsilon} $$"
   ]
  },
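  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of why only recent gradients matter, unroll the first equation starting from $s = 0$. After $t$ iterations, with $g_k$ the gradient at iteration $k$:\n",
    "\n",
    "$$ s_t = (1-\\beta) \\sum_{k=1}^{t} \\beta^{t-k} \\, g_k \\odot g_k $$\n",
    "\n",
    "With $\\beta = 0.9$, a squared gradient from 10 iterations ago is weighted by $0.9^{10} \\approx 0.35$ relative to the current one, whereas AdaGrad would weight it equally. Because $s$ starts at 0, the very first step is also larger than $\\alpha$ (ignoring $\\epsilon$):\n",
    "\n",
    "$$ s_1 = (1-\\beta)\\, g_1 \\odot g_1 \\quad\\Rightarrow\\quad |\\Delta\\theta| = \\frac{\\alpha}{\\sqrt{1-\\beta}} \\approx 3.16\\,\\alpha $$"
   ]
  },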
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [],
   "source": [
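    "def rmsprop(a, b, sa, sb, x, y):\n",
    "    epsilon = 1e-10\n",
    "    beta = 0.9    # decay rate of the squared-gradient average\n",
    "    n = len(x)    # number of samples (5)\n",
    "    alpha = 1e-1  # learning rate\n",
    "    y_hat = model(a, b, x)\n",
    "    # gradients of the cost J with respect to a and b\n",
    "    da = (1.0/n) * ((y_hat-y)*x).sum()\n",
    "    db = (1.0/n) * ((y_hat-y).sum())\n",
    "    # exponentially decaying average of the squared gradients\n",
    "    sa = beta*sa + (1-beta)*da*da\n",
    "    sb = beta*sb + (1-beta)*db*db\n",
    "    # epsilon goes under the square root, as in the formula, instead of\n",
    "    # being added to the accumulator on every step\n",
    "    a = a - alpha*da / np.sqrt(sa + epsilon)\n",
    "    b = b - alpha*db / np.sqrt(sb + epsilon)\n",
    "    return a, b, sa, sb"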
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [],
   "source": [
    "def iterate_rmsprop(x, y, target_cost):\n",
    "    a=0\n",
    "    b=0\n",
    "    sa=0\n",
    "    sb=0\n",
    "    for i in range(1000):\n",
    "        a, b, sa, sb = rmsprop(a,b, sa, sb, x,y)\n",
    "        print(sa, sb)\n",
    "        cost = cost_function(a, b, x, y)\n",
    "        if cost<target_cost:\n",
    "            break\n",
    "    return i"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5036209932651422 0.37845033453999993\n",
      "0.6666748340963962 0.5011779498389735\n",
      "0.7035624038659631 0.5290949609920477\n",
      "0.6846116014928423 0.5150052874480152\n",
      "0.6413660413519212 0.48260564824209634\n",
      "0.5892427995404017 0.44348725811529754\n",
      "0.5358164308595106 0.40335218230125897\n",
      "0.48463120243995755 0.3648742872141607\n",
      "0.4371551733415317 0.32916597779828294\n",
      "0.39382062269916923 0.2965594306913857\n",
      "0.3545748318305831 0.26702044359109756\n",
      "0.3191617858943141 0.24036036007818096\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "11"
      ]
     },
     "execution_count": 72,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "iterate_rmsprop(x, y, target_cost=sgd_cost)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Wow, RMSProp reaches the SGD cost after only 12 updates (the loop index returned is 11)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Adam\n",
    "\n",
    "Adam, which stands for adaptive moment estimation, combines the ideas of\n",
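    "momentum optimization and RMSProp: $m$ is an exponentially decaying average of the negated\n",
    "past gradients (momentum) and $s$ is an exponentially decaying average of the past squared\n",
    "gradients (RMSProp). Because both start at 0, they are biased towards 0 in early iterations;\n",
    "$\\hat{m}$ and $\\hat{s}$ correct this bias, where $t$ is the iteration number starting at 1.\n",
    "\n",
    "$$ m = \\beta_1 m - (1-\\beta_1)\\frac{\\partial J}{\\partial \\theta}$$\n",
    "$$ s = \\beta_2 s + (1-\\beta_2) \\frac{\\partial J}{\\partial \\theta} \\odot \\frac{\\partial J}{\\partial \\theta} $$\n",
    "$$ \\hat{m} = \\frac{m}{1-\\beta_1^t} $$\n",
    "$$ \\hat{s} = \\frac{s}{1-\\beta_2^t} $$\n",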
    "$$ \\theta = \\theta + \\alpha \\hat{m} \\oslash \\sqrt{\\hat{s}+\\epsilon} $$ "
   ]
  },
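  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick worked check of the bias correction, take the very first iteration ($t = 1$) with $m = s = 0$ and gradient $g$. Ignoring $\\epsilon$:\n",
    "\n",
    "$$ \\hat{m} = \\frac{-(1-\\beta_1)\\,g}{1-\\beta_1} = -g, \\qquad \\hat{s} = \\frac{(1-\\beta_2)\\,g \\odot g}{1-\\beta_2} = g \\odot g $$\n",
    "\n",
    "$$ \\theta = \\theta + \\alpha\\,(-g) \\oslash \\sqrt{g \\odot g} = \\theta - \\alpha\\,\\mathrm{sign}(g) $$\n",
    "\n",
    "So every parameter moves by about $\\alpha$ on the first step, whatever the scale of its gradient. Here both gradients are negative at $a = b = 0$ (the first printed $m$ values below are positive), so with $\\alpha = 0.1$ the first update takes $a$ and $b$ from 0 to about 0.1."
   ]
  },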
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [],
   "source": [
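    "def adam(a, b, ma, mb, sa, sb, t, x, y):\n",
    "    epsilon = 1e-10\n",
    "    beta1 = 0.9   # momentum decay rate\n",
    "    beta2 = 0.9   # squared-gradient decay rate (the common default is 0.999)\n",
    "    n = len(x)    # number of samples (5)\n",
    "    alpha = 1e-1  # learning rate\n",
    "    y_hat = model(a, b, x)\n",
    "    da = (1.0/n) * ((y_hat-y)*x).sum()\n",
    "    db = (1.0/n) * ((y_hat-y).sum())\n",
    "    # first moment: decaying average of the negated gradients\n",
    "    ma = beta1 * ma - (1-beta1)*da\n",
    "    mb = beta1 * mb - (1-beta1)*db\n",
    "    # second moment: decaying average of the squared gradients\n",
    "    sa = beta2 * sa + (1-beta2)*da*da\n",
    "    sb = beta2 * sb + (1-beta2)*db*db\n",
    "    # bias correction; t is the 1-based iteration number\n",
    "    ma_hat = ma/(1-beta1**t)\n",
    "    mb_hat = mb/(1-beta1**t)\n",
    "    sa_hat = sa/(1-beta2**t)\n",
    "    sb_hat = sb/(1-beta2**t)\n",
    "    \n",
    "    # m already points downhill, so the update adds it\n",
    "    a = a + alpha*ma_hat / np.sqrt(sa_hat + epsilon)\n",
    "    b = b + alpha*mb_hat / np.sqrt(sb_hat + epsilon)\n",
    "    \n",
    "    return a, b, ma, mb, sa, sb"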
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [],
   "source": [
    "def iterate_adam(x, y, target_cost):\n",
    "    a=0\n",
    "    b=0\n",
    "    ma=0\n",
    "    mb=0\n",
    "    sa=0\n",
    "    sb=0\n",
    "    for i in range(1000):\n",
    "        a, b, ma, mb, sa, sb = adam(a,b, ma, mb, sa, sb, i+1, x, y)\n",
    "        print(f\"{ma}\\t{mb}\\t{sa}\\t{sb}\")\n",
    "        cost = cost_function(a, b, x, y)\n",
    "        if cost<target_cost:\n",
    "            break\n",
    "    return i"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.22441501579999992\t0.19453799999999996\t0.5036209931651422\t0.3784503344399999\n",
      "0.4016192340199999\t0.3481753999999999\t0.8518430281932292\t0.6402109361703999\n",
      "0.5363759738971576\t0.46503883320196493\t1.07262411300758\t0.8062610188683121\n",
      "0.6330149656607605\t0.5488792785222582\t1.1911922341823753\t0.895531357068881\n",
      "0.6954922736152254\t0.603123968069419\t1.2302760874303023\t0.925077503166402\n",
      "0.7274510180595801\t0.6309289839343443\t1.2102871620886997\t0.9102165370096031\n",
      "0.7322854961686895\t0.6352347958137915\t1.1494443580778129\t0.8646207447731215\n",
      "0.7132116293046631\t0.6188272743886445\t1.0638272189106417\t0.8003578054256075\n",
      "0.6733462945505432\t0.5844063952845583\t0.9673391882821758\t0.7278635560226704\n",
      "0.6157957995957747\t0.5346628581915572\t0.8715625623331551\t0.6558335963281313\n",
      "0.5437474965633194\t0.4723574225921566\t0.7855022477276057\t0.5910315423836943\n",
      "0.4605465905355612\t0.4003874237356603\t0.7152614958663444\t0.5380462226395261\n",
      "0.369725397655793\t0.3218121173778424\t0.6637757717878439\t0.49909226801269685\n",
      "0.274947831152388\t0.23980462538427427\t0.6308124057706669\t0.47400962325270635\n",
      "0.17985640530446415\t0.15751947629598656\t0.6134242262809029\t0.46060302566421035\n",
      "0.08785962237036239\t0.07790852397792443\t0.6068582956517541\t0.45532244789580134\n",
      "0.001933592815800661\t0.003549186015826572\t0.6056783659441023\t0.4541038358094726\n",
      "-0.07550214246825582\t-0.06346307358846585\t0.6047743758547716\t0.45312546332404074\n",
      "-0.14262068152365687\t-0.12154597442587403\t0.6000511654663822\t0.4493241456796001\n",
      "-0.198158083452391\t-0.16960472211934757\t0.5887657091557674\t0.4406482004363841\n",
      "-0.2413503318769978\t-0.2069774262424965\t0.5695892904193119\t0.4261043208993387\n",
      "-0.2718627632509636\t-0.23337399319470653\t0.5424938152075657\t0.40567262875423904\n",
      "-0.28972757335619564\t-0.2488220132766666\t0.5085404375723384\t0.38014845346026394\n",
      "-0.29529721236352313\t-0.25362640497715555\t0.4696181652662369\t0.3509465461704309\n",
      "-0.2892167641183978\t-0.2483454971123596\t0.4281550327777841\t0.3198846514087885\n",
      "-0.27241533612699387\t-0.2337835718350304\t0.3868085337177365\t0.2889514543955107\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "25"
      ]
     },
     "execution_count": 87,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "iterate_adam(x, y, target_cost=sgd_cost)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
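    "Adam reaches the SGD cost after 26 updates (returned index 25): far fewer than AdaGrad, but more than RMSProp on this small problem."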
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# NAdam\n",
    "\n",
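    "NAdam is Adam plus the Nesterov trick: instead of measuring the gradient at the current\n",
    "parameters $\\theta$, it measures it slightly ahead, at $\\theta + \\beta_1 m$, in the direction the\n",
    "momentum is already pushing. Everything else, including the bias correction with the\n",
    "iteration number $t$, is the same as in Adam. It often converges slightly faster than Adam.\n",
    "\n",
    "\n",
    "$$ m = \\beta_1 m - (1-\\beta_1)\\frac{\\partial J(\\theta+\\beta_1 m)}{\\partial \\theta}$$\n",
    "$$ s = \\beta_2 s + (1-\\beta_2) \\frac{\\partial J(\\theta+\\beta_1 m)}{\\partial \\theta} \\odot \\frac{\\partial J(\\theta+\\beta_1 m)}{\\partial \\theta} $$\n",
    "$$ \\hat{m} = \\frac{m}{1-\\beta_1^t} $$\n",
    "$$ \\hat{s} = \\frac{s}{1-\\beta_2^t} $$\n",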
    "$$ \\theta = \\theta + \\alpha \\hat{m} \\oslash \\sqrt{\\hat{s}+\\epsilon} $$ "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def nadam(a, b, ma, mb, sa, sb, t, x, y):\n",
    "    epsilon = 1e-10\n",
    "    beta1 = 0.9\n",
    "    beta2 = 0.9\n",
    "    n = len(x)    # number of samples (5)\n",
    "    alpha = 1e-1\n",
    "    # The only change from adam(): the gradient is evaluated at a\n",
    "    # look-ahead point instead of at (a, b). Note that this code looks\n",
    "    # ahead by the full momentum (a + ma, b + mb), whereas the formula\n",
    "    # above uses theta + beta1 * m.\n",
    "    y_hat = model(a+ma, b+mb, x)\n",
    "    da = (1.0/n) * ((y_hat-y)*x).sum()\n",
    "    db = (1.0/n) * ((y_hat-y).sum())\n",
    "    ma = beta1 * ma - (1-beta1)*da\n",
    "    mb = beta1 * mb - (1-beta1)*db\n",
    "    sa = beta2 * sa + (1-beta2)*da*da\n",
    "    sb = beta2 * sb + (1-beta2)*db*db\n",
    "    ma_hat = ma/(1-beta1**t)\n",
    "    mb_hat = mb/(1-beta1**t)\n",
    "    sa_hat = sa/(1-beta2**t)\n",
    "    sb_hat = sb/(1-beta2**t)\n",
    "    \n",
    "    a = a + alpha*ma_hat / np.sqrt(sa_hat + epsilon)\n",
    "    b = b + alpha*mb_hat / np.sqrt(sb_hat + epsilon)\n",
    "    \n",
    "    return a, b, ma, mb, sa, sb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def iterate_nadam(x, y, target_cost):\n",
    "    a=0\n",
    "    b=0\n",
    "    ma=0\n",
    "    mb=0\n",
    "    sa=0\n",
    "    sb=0\n",
    "    for i in range(1000):\n",
    "        a, b, ma, mb, sa, sb = nadam(a,b, ma, mb, sa, sb, i+1, x, y)\n",
    "        print(f\"{ma}\\t{mb}\\t{sa}\\t{sb}\")\n",
    "        cost = cost_function(a, b, x, y)\n",
    "        if cost<target_cost:\n",
    "            break\n",
    "    return i"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.22441501579999992\t0.19453799999999996\t0.5036209931651422\t0.3784503344399999\n",
      "0.34945317673264553\t0.30303326197140557\t0.6707614023970666\t0.5043149255896258\n",
      "0.4086630974407369\t0.3545220463126516\t0.6923373513038595\t0.5207829264938997\n",
      "0.42479852013231584\t0.36869895522307994\t0.655595591180131\t0.4933351229538528\n",
      "0.41338276013846714\t0.3590014409177542\t0.5996858101920048\t0.45138499366847756\n",
      "0.3850680115669883\t0.33465143621864796\t0.5414133518432255\t0.40758055150153066\n",
      "0.34719317806152206\t0.30200342133598246\t0.487276010490025\t0.36682917334514187\n",
      "0.3048126007276661\t0.26543462733685447\t0.43913535841653084\t0.3305518278022634\n",
      "0.2613794636593488\t0.22793641452525237\t0.3968993337519565\t0.2986967105147332\n",
      "0.21920663475537913\t0.19151305757275972\t0.35978057495684657\t0.2706847309093604\n",
      "0.17978502825765205\t0.15745724253108329\t0.32686534752782104\t0.24583770178851216\n",
      "0.1440100206019037\t0.12654600408884117\t0.29734596861667506\t0.22355385981586914\n",
      "0.11234680538118677\t0.09918382962474853\t0.27059121178700213\t0.20336160118017668\n",
      "0.08495332320331386\t0.0755090552163363\t0.24614315931263117\t0.1849178241182976\n",
      "0.061772280422955676\t0.05547351561060023\t0.22368554429852547\t0.16798470258860718\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "14"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "iterate_nadam(x, y, target_cost=sgd_cost)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "NAdam reaches the SGD cost after only 15 updates (returned index 14), compared with 26 for Adam."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
